All nodes distance k in binary tree [DFS + BFS]

Time: O(N); Space: O(N); medium

We are given a binary tree (with root node root), a target node, and an integer value K.

Return a list of the values of all nodes that have a distance K from the target node. The answer can be returned in any order.

Example 1:

Input: root = {TreeNode} [3,5,1,6,2,0,8,None,None,7,4], target = 5, K = 2

Output: [7,4,1]

Explanation:

  • The nodes that are a distance 2 from the target node (with value 5) have values 7, 4, and 1.

  • Note that the inputs “root” and “target” are actually TreeNodes.

  • The descriptions of the inputs above are just serializations of these objects.

Example 2:

Input: root = {TreeNode} [1,2,3,4], target = 2, K = 1

Output: [1,4]

Explanation:

    1
   / \
  2   3
 /
4
  • Nodes 1 and 4 are each at distance 1 from node 2.

Notes:

  • The given tree is non-empty.

  • The node values are unique, and 0 <= node.val <= 500.

  • The target node is a node in the tree.

  • 0 <= K <= 1000.

[1]:
class TreeNode(object):
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None

Auxiliary Tools

[2]:
class TreeTasks(object):
    def create_binary_tree(self, nums):
        """Build a tree from a level-order list using array (heap) indexing:
        the children of nums[i] sit at nums[2*i + 1] and nums[2*i + 2].
        None entries become missing children, not nodes with val None."""
        return self.insertLevelOrder(nums, 0, len(nums))

    def insertLevelOrder(self, nums, idx, len_nums):
        # Base case: index past the end of the list, or an explicit None slot
        if idx >= len_nums or nums[idx] is None:
            return None
        root = TreeNode(nums[idx])
        # Insert left child
        root.left = self.insertLevelOrder(nums, 2 * idx + 1, len_nums)
        # Insert right child
        root.right = self.insertLevelOrder(nums, 2 * idx + 2, len_nums)
        return root
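The builder in TreeTasks places the children of nums[i] at indices 2*i + 1 and 2*i + 2, which is the array (heap) layout. LeetCode's own serialization, which these example inputs appear to follow, instead omits the children of missing nodes, so the two layouts agree on both examples here but can diverge on sparser trees. The sketch below is a hypothetical queue-based builder for the LeetCode-style layout; build_tree_leetcode and level_order are illustrative names, not part of the original notebook.

```python
from collections import deque


class TreeNode(object):
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def build_tree_leetcode(values):
    """Hypothetical helper: build a tree from a LeetCode-style level-order
    list, where the children of a None slot are NOT listed."""
    if not values or values[0] is None:
        return None
    it = iter(values)
    root = TreeNode(next(it))
    queue = deque([root])
    while queue:
        node = queue.popleft()
        # Consume the next two slots as this node's left and right children
        for side in ("left", "right"):
            val = next(it, None)  # Missing trailing slots count as None
            if val is not None:
                child = TreeNode(val)
                setattr(node, side, child)
                queue.append(child)
    return root


def level_order(root):
    """Return node values level by level, skipping missing children."""
    out, queue = [], deque([root] if root else [])
    while queue:
        node = queue.popleft()
        out.append(node.val)
        for child in (node.left, node.right):
            if child:
                queue.append(child)
    return out


tree = build_tree_leetcode([1, None, 2, 3])
print(level_order(tree))    # [1, 2, 3]
print(tree.right.left.val)  # 3
```

For [1, None, 2, 3] the LeetCode layout makes 3 the left child of 2, whereas the array layout would assign index 3 to a child of the empty slot at index 1.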

Solution

[3]:
import collections

class Solution1(object):
    """
    Time: O(N)
    Space: O(N)
    """
    def distanceK(self, root, target, K):
        """
        :type root: TreeNode
        :type target: TreeNode
        :type K: int
        :rtype: List[int]
        """
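        def dfs(parent, child, neighbors):
            # Record each parent-child edge in both directions, turning the
            # tree into an undirected graph keyed by node value (safe
            # because the node values are unique).
            if not child:
                return
            if parent:
                neighbors[parent.val].append(child.val)
                neighbors[child.val].append(parent.val)
            dfs(child, child.left, neighbors)
            dfs(child, child.right, neighbors)

        neighbors = collections.defaultdict(list)
        dfs(None, root, neighbors)
        # BFS frontier: values at the current distance from the target
        bfs = [target.val]
        # Values already visited, so the search never walks back inward
        lookup = set(bfs)

        # Expand the frontier by one edge per round, K rounds in total
        for _ in range(K):
            bfs = [nei for node in bfs
                   for nei in neighbors[node]
                   if nei not in lookup]
            lookup |= set(bfs)

        return bfs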
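Solution1 turns the tree into an undirected graph keyed by node value, which is safe here only because the values are unique. A common variant, sketched below under the same problem setup, keys a parent map by the node objects themselves and runs the BFS on nodes, so it would also cope with duplicate values. This is an alternative formulation, not the notebook's method; distance_k_parent_map is a hypothetical name, and unlike Solution1 it requires target to be the actual node inside the tree (a detached TreeNode would raise KeyError in the parent lookup).

```python
from collections import deque


class TreeNode(object):
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def distance_k_parent_map(root, target, k):
    """Sketch of a variant of Solution1: map each node to its parent
    (keyed by node object, not value), then BFS outward from target."""
    parent = {root: None}
    stack = [root]
    # Iterative DFS to record every node's parent
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child:
                parent[child] = node
                stack.append(child)

    # BFS in rings of equal distance from target
    frontier = deque([target])
    seen = {target}
    for _ in range(k):
        for _ in range(len(frontier)):
            node = frontier.popleft()
            for nei in (node.left, node.right, parent[node]):
                if nei and nei not in seen:
                    seen.add(nei)
                    frontier.append(nei)
    return [node.val for node in frontier]


# Example 2 tree: 1 -> (2 -> 4), 3; the target is the node with value 2
root = TreeNode(1)
root.left, root.right = TreeNode(2), TreeNode(3)
root.left.left = TreeNode(4)
print(sorted(distance_k_parent_map(root, root.left, 1)))  # [1, 4]
```

Draining the deque one ring at a time (the inner range(len(frontier)) loop) keeps every node in the frontier at the same distance, so after k rounds the frontier holds exactly the answer.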
[4]:
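s = Solution1()
t = TreeTasks()

# Example 1. Solution1 reads only target.val, so a detached node with the
# same value can stand in for the target (node values are unique).
root = [3,5,1,6,2,0,8,None,None,7,4]
target = TreeNode(5)
K = 2
tree = t.create_binary_tree(root)
res = s.distanceK(tree, target, K)
# The answer may come in any order; skip None placeholders, if any
assert sorted(r for r in res if r is not None) == [1, 4, 7]

# Example 2
root = [1,2,3,4]
target = TreeNode(2)
K = 1
tree = t.create_binary_tree(root)
res = s.distanceK(tree, target, K)
assert sorted(res) == [1, 4]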
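The tests above pass a detached TreeNode(5) as target, which works only because Solution1 reads nothing but target.val. The problem statement says target is a node in the tree, so a stricter test would look that node up first. A minimal sketch of such a lookup, assuming unique values; find_node is a hypothetical helper, not part of the original notebook:

```python
class TreeNode(object):
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def find_node(root, val):
    """Hypothetical helper: return the node holding `val`, or None.
    Relies on node values being unique, as the problem guarantees."""
    stack = [root] if root else []
    while stack:
        node = stack.pop()
        if node.val == val:
            return node
        # Push existing children; missing (None) children are skipped
        stack.extend(child for child in (node.left, node.right) if child)
    return None


# Example 2 tree: 1 -> (2 -> 4), 3
root = TreeNode(1)
root.left, root.right = TreeNode(2), TreeNode(3)
root.left.left = TreeNode(4)
target = find_node(root, 2)
print(target is root.left)  # True
```

With it, Example 1 could call s.distanceK(tree, find_node(tree, 5), 2), passing the real node as the problem intends.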